# Contents (search for the section banners below):
#   (1) MODEL DEPENDENCY EXAMINATION
#   (2) PURE IMPUTATION ESTIMATOR
# NOTE: For additional ML algorithms (Appendix A), see line 1211 of 04_Analyses.R

# Run this script in a fresh R session (e.g. `Rscript <this file>`) rather
# than wiping the global environment with rm(list = ls()): that call deletes
# the user's objects but does not detach packages or reset options, so it
# does not actually guarantee a clean state.
library(tidyverse)
library(caret)
library(haven)
library(stringr)
library(survey)
library(splines)
library(rlang)
library(lmtest)
library(sandwich)
library(broom)
library(MASS)
library(matrixStats)
library(naniar)
library(VIM)
library(mice)
library(ggpubr)
library(haven)
library(splines)
library(survey)
library(rbw)
library(gbm)
library(xgboost)
library(BART)
library(rlang)
library(rsample)
library(mgcv)
# Project helpers.
# NOTE(review): trim() and whatif() are called later but are not defined in
# this script or by an obviously attached package; presumably zmisc.R
# provides them -- confirm.
source("zmisc.R")
set.seed(666)

# Working directory: run from the "R Scripts" folder, e.g.
# setwd("~/Dropbox (Harvard University)/Gov 2001 Rep Paper/R Scripts")
# setwd("~/Dropbox (Harvard University)/Gov_2001_Rep/R Scripts")


#############################################
##### (1) MODEL DEPENDENCY EXAMINATION #####
#############################################

# Shared ggplot theme for every figure in this script (defined once; it was
# previously duplicated verbatim further down).
mytheme <- theme_minimal(base_size = 16) +
  theme(legend.position = "bottom",
        plot.title = element_text(hjust = 0.5),
        plot.caption = element_text(color = "grey30"),
        axis.title.y = element_text(margin = margin(t = 0, r = 10, b = 0, l = 0)))
theme_set(mytheme)

# Downes & Sechser replication data (Appendix C).
data <- read_dta("Datasets/Downes-Sechser-Appendix-C.dta")

# Model covariates, captured as quosures so they can be spliced into
# dplyr::select() calls and into model formulas.
variables <- vars(majmaj_ks, majmin_ks, minmaj_ks, capshare_a_ks, contig_ks, swt_dyad, tau_lead_a, tau_lead_b,
                  territory_ks, government_ks, policy_ks, other_ks)

# Democracy indicators used for the Schultz-style specification.
schultz_dems <- vars(democ_a_ks, democ_b_ks, demdem_ks)

# One dummy per world-war year.
year_dummies <- vars(dummy1914_ks, dummy1915_ks, dummy1916_ks, dummy1917_ks, dummy1918_ks, dummy1939_ks,
                     dummy1940_ks, dummy1941_ks, dummy1942_ks, dummy1943_ks, dummy1945_ks)

# Collapse a list of quosures into one left-nested call `a + b + c + ...`.
sum_rhs_terms <- function(terms) {
  Reduce(function(lhs, rhs) expr(!!lhs + !!rhs), lapply(terms, get_expr))
}

ww_rhs <- sum_rhs_terms(c(variables, year_dummies))
no_ww_rhs <- sum_rhs_terms(variables)
schultz_dems_rhs <- sum_rhs_terms(schultz_dems)

# Schultz democracy measure attempt: no world wars.
# Keep the covariates, democracy indicators and the 100-day failure outcome,
# drop incomplete rows, and add numeric / factor codings of the treatment
# (side A democracy) and of the outcome.
d <- schultz_df <- filter(data, worldwar_ks == 0) %>%
  dplyr::select(!!!variables, !!!schultz_dems, failure_100_ks) %>%
  drop_na() %>%
  mutate(treat_num = as.double(democ_a_ks),
         treat = factor(democ_a_ks, labels = c("no", "yes")),
         outcome = factor(failure_100_ks, labels = c("no", "yes")),
         outcome_num = as.double(failure_100_ks))

names(schultz_df)
nrow(schultz_df) # 198 non world war observations

factual <- counter <- d

# Counterfactual: flip the treatment in EVERY column that encodes it.
# treat_num equals democ_a_ks in the observed data, so flipping only one of
# them yields points that can never lie inside the convex hull, which made
# the check below vacuous.
counter$treat_num <- 1 - counter$treat_num
counter$democ_a_ks <- 1 - counter$democ_a_ks
# NOTE(review): demdem_ks is described elsewhere in this script as an
# interaction term; if it is democ_a_ks * democ_b_ks it should be recomputed
# for `counter` as well.

# Convex hull: compare covariates + treatment only. Factor codings are
# non-numeric and the outcome columns are not model inputs, so drop them.
my.result <- whatif(
  data = dplyr::select(factual, -c(treat, outcome, failure_100_ks, outcome_num)),
  cfact = dplyr::select(counter, -c(treat, outcome, failure_100_ks, outcome_num))
)
sum(as.numeric(my.result$in.hull))


# Exploration of potential model dependency
#################################################################
############# ATE via debiased machine learning ###########
#################################################################
## Cross-fitted AIPW: random forest and radial SVM learners for both the
## propensity score and the outcome model. Each of the K outer folds is used
## once as the held-out `main` sample; nuisance models are trained on the
## remaining `aux` rows and their predictions are stored on `main`.

# Propensity model. demdem_ks is excluded because it is an interaction term
# the flexible learners can construct themselves.
a_formula <- as.formula(expr(treat ~ !!no_ww_rhs + democ_b_ks))

# Outcome model; the treatment enters as the numeric column democ_a_ks.
y_formula <- as.formula(expr(outcome ~ !!no_ww_rhs + democ_b_ks + democ_a_ks))

K <- 5
whole_folds <- createFolds(d$outcome, k = K)
predicted_list <- vector(mode = "list", K)

###############################################
#  Use each fold as an estimation sample once
###############################################
for (k in seq_len(K)) {

  cat("fold ", k, "\n")
  aux <- d[-whole_folds[[k]], ]
  main <- d[whole_folds[[k]], ]

  ## Data matrices for train() and predict()
  aux_X <- model.matrix(y_formula, data = aux)[, -1]
  main_X_a <- model.matrix(a_formula, data = main)[, -1]

  # Counterfactual copies of the held-out sample with the treatment forced to
  # 0 and 1. Both outcome learners read the treatment from democ_a_ks (the
  # treatment term in y_formula), so that is the column to overwrite.
  # Setting only `treat` / `treatyes` left democ_a_ks untouched and made the
  # y0 and y1 predictions identical.
  main0 <- main1 <- main
  main0$democ_a_ks <- 0
  main1$democ_a_ks <- 1
  main_X_a0 <- model.matrix(y_formula, data = main0)[, -1]
  main_X_a1 <- model.matrix(y_formula, data = main1)[, -1]

  # trainControl(index = ) expects the rows used for TRAINING in each
  # resample. createFolds() returns the held-out rows unless
  # returnTrain = TRUE, which would tune every model on only 1/5 of `aux`.
  myFolds <- createFolds(aux$outcome, k = 5, returnTrain = TRUE)

  ###############################
  # Propensity score models
  ###############################
  aControl <- trainControl(
    summaryFunction = mnLogLoss,
    number = 5,
    method = "cv",
    classProbs = TRUE,
    savePredictions = TRUE,
    index = myFolds
  )

  # Random forest
  cat(" ps: rf", "\n")
  rf_ps <- train(
    a_formula,
    data = aux,
    method = "ranger",
    trControl = aControl,
    tuneLength = 3
  )

  # One prediction call gives both class-probability columns.
  rf_ps_prob <- predict(rf_ps$finalModel, data = main, type = "response")$predictions
  main$rf_ps0 <- rf_ps_prob[, "no"]
  main$rf_ps1 <- rf_ps_prob[, "yes"]

  # Support Vector Machine
  cat(" ps: svm", "\n")
  svm_ps <- train(
    a_formula,
    data = aux,
    method = "svmRadial",
    trControl = aControl,
    prob.model = TRUE,
    tuneGrid = expand.grid(sigma = seq(0.1, 1, 0.1), C = 1))

  svm_ps_prob <- kernlab::predict(svm_ps$finalModel, main_X_a, type = "probabilities")
  main$svm_ps1 <- svm_ps_prob[, "yes"]
  main$svm_ps0 <- svm_ps_prob[, "no"]

  ###############################
  # Outcome models
  ###############################
  yControl <- trainControl(method = "cv",
                           index = myFolds,
                           summaryFunction = mnLogLoss,
                           classProbs = TRUE)

  # Random Forest (matrix interface; ranger looks predictors up by name).
  cat(" outcome: rf", "\n")
  rf_y <- train(aux_X, aux$outcome, method = "ranger",
                trControl = yControl, tuneLength = 3)

  main$rf_y0 <- predict(rf_y$finalModel, data = main0, type = "response")$predictions[, "yes"]
  main$rf_y1 <- predict(rf_y$finalModel, data = main1, type = "response")$predictions[, "yes"]

  # Support Vector Machine
  cat(" outcome: svm", "\n")
  svm_y <- train(
    y_formula,
    data = aux,
    method = "svmRadial",
    trControl = yControl,
    tuneGrid = expand.grid(sigma = c(0.005, 0.01), C = c(0.5, 1, 5, 10)))

  main$svm_y0 <- kernlab::predict(svm_y$finalModel, main_X_a0, type = "probabilities")[, "yes"]
  main$svm_y1 <- kernlab::predict(svm_y$finalModel, main_X_a1, type = "probabilities")[, "yes"]

  predicted_list[[k]] <- main
}

# Reassemble the cross-fitted held-out folds into one data set.
predicted_df <- bind_rows(predicted_list)

### CONSTRUCT PSEUDO OUTCOMES
# For a propensity learner `psmod` and an outcome learner `ymod`, append the
# AIPW pseudo-outcomes <psmod>_<ymod>_y0 / _y1 and their difference _TE.
# Propensity scores are passed through trim() (defined outside this script).
add_pseudo_outcomes <- function(df, psmod, ymod) {
  a <- df$treat_num
  y <- df$outcome_num
  mu0 <- df[[str_c(ymod, "_y0")]]
  mu1 <- df[[str_c(ymod, "_y1")]]
  pair <- str_c(psmod, "_", ymod)

  df[[str_c(pair, "_y0")]] <- mu0 + (1 - a) / trim(df[[str_c(psmod, "_ps0")]]) * (y - mu0)
  df[[str_c(pair, "_y1")]] <- mu1 + a / trim(df[[str_c(psmod, "_ps1")]]) * (y - mu1)
  df[[str_c(pair, "_TE")]] <- df[[str_c(pair, "_y1")]] - df[[str_c(pair, "_y0")]]
  df
}

methods <- c("rf", "svm")

for (psmod in methods) {
  for (ymod in methods) {
    predicted_df <- add_pseudo_outcomes(predicted_df, psmod, ymod)
  }
}


# ATE for every (propensity learner, outcome learner) pair: the mean of the
# AIPW pseudo-outcome differences. An intercept-only svyglm returns that
# mean together with its design-based standard error.
holder_ATE <- vector(mode = "list")

# No weights are supplied, so svydesign() warns that it assumes equal
# probabilities; that is the intended design here.
design <- suppressWarnings(svydesign(ids = ~ 1, data = predicted_df))

for (psmod in methods) {
  for (ymod in methods) {
    pair <- paste0(psmod, "_", ymod)
    te_model <- svyglm(as.formula(paste0(pair, "_TE ~ 1")), design = design)

    # Keep each fitted model available as <psmod>_<ymod>_modTE.
    assign(paste0(pair, "_modTE"), te_model)

    # Store each result under its own key (previously every entry was stored
    # under the empty name "").
    # NOTE(review): the "educ2" column name looks like a leftover from another
    # analysis; kept unchanged for compatibility.
    holder_ATE[[paste0(pair, "_holder_ate")]] <- tibble(
      "educ2" = "overall_ate",
      "model" = pair,
      "est" = unname(coef(te_model)[1]),
      "se" = sqrt(unname(vcov(te_model)[1, 1]))
    )
  }
}

# One row per learner pair.
final_output <- bind_rows(holder_ATE)


### OLD STUFF ###
# Original model ATE: logit specification without the world-war years.
tab_5_form_no_ww <- as.formula(expr(failure_100_ks ~ democ_a_ks + democ_b_ks + demdem_ks + !!no_ww_rhs))
mod2 <- glm(tab_5_form_no_ww, data = d, family = binomial())

# Coefficients as a one-row data frame with alphabetically sorted columns
# (coef() instead of the partially matched `mod2$coeff`).
coefs <- coef(mod2) %>% as.data.frame() %>% t() %>% as.data.frame() %>% dplyr::select(sort(names(.)))

# Unit-level effect of setting democ_a_ks to 1 vs 0, on the PROBABILITY scale.
# (Differencing linear predictors gives a log-odds difference, which for this
# additive specification is just the democ_a_ks coefficient for every unit.)
# predict() also matches coefficients to columns by name.
# NOTE(review): demdem_ks is not updated when democ_a_ks is set; if it is
# the democ_a_ks * democ_b_ks interaction it should be recomputed here.
dem1 <- mutate(predicted_df, democ_a_ks = 1)
dem0 <- mutate(predicted_df, democ_a_ks = 0)
p1 <- predict(mod2, newdata = dem1, type = "response")
p0 <- predict(mod2, newdata = dem0, type = "response")

diff <- p1 - p0


##### Interpolation vs extrapolation
# Compare factual and counterfactual (treatment flipped) predicted
# probabilities of the original logit (mod2) with the outcome random forest.
# Both models are evaluated on the SAME rows (predicted_df), so the columns
# of `output` describe the same dyad on every row. Previously the logit
# used predicted_df (fold order) and the forest used d (original order).
factual_rf <- predicted_df
counter_rf <- mutate(predicted_df, democ_a_ks = 1 - democ_a_ks)

### Original model factual / counterfactual predicted values
predicted_factual_og <- unname(predict(mod2, newdata = factual_rf, type = "response"))
predicted_counter_og <- unname(predict(mod2, newdata = counter_rf, type = "response"))

### Random forest factual / counterfactual predicted values
# NOTE(review): rf_y is the forest from the last cross-fitting fold only, so
# these predictions are in-sample for the rows it was trained on.
predicted_factual_rf <- predict(rf_y$finalModel, data = factual_rf, type = "response")$predictions[, "yes"]
predicted_counter_rf <- predict(rf_y$finalModel, data = counter_rf, type = "response")$predictions[, "yes"]

output <- data.frame(predicted_factual_og, predicted_counter_og, predicted_factual_rf, predicted_counter_rf)

factual_plot <- ggplot(output, aes(predicted_factual_og, predicted_factual_rf)) +
  geom_point() +
  xlab("Predicted Factual \n(DS Model)") +
  ylab("Predicted Factual \n(Machine Learning Model)")
print(factual_plot)
ggsave(path = "Figures", filename = "Factual.png", plot = factual_plot, width = 6, height = 5)

counter_plot <- ggplot(output, aes(predicted_counter_og, predicted_counter_rf)) +
  geom_point() +
  xlab("Predicted Counterfactual \n(DS Model)") +
  ylab("Predicted Counterfactual \n(Machine Learning Model)")
print(counter_plot)
ggsave(path = "Figures", filename = "Counterfactual.png", plot = counter_plot, width = 6, height = 5)


##########################################
#### (2) PURE IMPUTATION ESTIMATOR ###
########################################
# Variable groups for the mediator (treatment) and outcome models fit below.
x_vars <- vars(major_a, alliance_gg, cinc_a, cinc_b, contig_ks)
# Democracy scores enter additively: no interaction variable is built for the ML learners.
dem_variables <- vars(democ_a, democ_b)
z_vars_b <- vars(recruit)
z_vars_c <- vars(female_suffrage, milex_quintile, reg_power_rank, e_peaveduc) # TO CONSIDER: put cinc_a and major_a in Z

# Each group becomes one left-nested `a + b + ...` call.
plus_fold <- function(lhs, rhs) expr(!!lhs + !!rhs)
x_rhs <- Reduce(plus_fold, lapply(x_vars, get_expr))
z_rhs <- Reduce(plus_fold, lapply(c(z_vars_b, z_vars_c), get_expr))
dem_variables_rhs <- Reduce(plus_fold, lapply(dem_variables, get_expr))

# Mediator model and outcome model formulas.
m_form <- as.formula(expr(mediator_factor ~ !!x_rhs + !!z_rhs + !!dem_variables_rhs))
y_form <- as.formula(expr(compliance ~ mediator_factor + !!x_rhs + !!z_rhs + !!dem_variables_rhs))

### Estimators with multiple imputation
# load data
# WITH MIDS CONTROL GROUP
# imputation_all  <- readRDS("imputed_full.RDS") %>%
#   map(., ~ mutate(.x, mediator = mediator,
#                   mediator_factor = factor(mediator, labels = c("no", "yes")),
#                   outcome = factor(if_else(compliance %in% c(2,1), 1, 0),  labels = c("yes", "no"))))

# WITH CASE CONTROL GROUP
# Add the factor codings used by caret to each imputed data set.
# NOTE(review): factor() orders the levels 0, 1, so with labels
# c("yes", "no") compliance codes 1 or 2 become "no" and every other code
# becomes "yes" -- confirm this direction is intended.
add_model_factors <- function(imputed) {
  mutate(imputed,
         mediator = mediator,
         mediator_factor = factor(mediator, labels = c("no", "yes")),
         outcome = factor(if_else(compliance %in% c(2, 1), 1, 0), labels = c("yes", "no")))
}

imputation_all <- lapply(readRDS("full_sample_case_control_imputed.RDS"), add_model_factors)


# full_sample_case_control_imputed <- readRDS("full_sample_case_control_imputed.RDS")[[1]]

# Cross-fitted nuisance models on a single imputed data set.
# NOTE(review): this used to be the body of a loop over imputations; only
# imputation `i` is analysed now. `i` must be assigned BEFORE it is
# reported (it was previously used one line before being set).
i <- 1
cat("imputed sample ", i, "\n")
d <- imputation_all[[i]]

K <- 5
whole_folds <- createFolds(d$outcome, k = K)
predicted_list <- vector(mode = "list", K)

for (k in seq_len(K)) {

  cat("fold ", k, "\n")
  aux <- d[-whole_folds[[k]], ]
  main <- d[whole_folds[[k]], ]

  # The outcome model is trained on mediator == 1 units only, so its design
  # matrix must come from exactly those rows; building it from all of `aux`
  # gave train() an x with more rows than y.
  aux_m1 <- filter(aux, mediator == 1)
  aux_X <- model.matrix(y_form, data = aux_m1)[, -1]

  # Counterfactual prediction sets. ranger looks predictors up by name and
  # the model-matrix dummy for the mediator is `mediator_factoryes`.
  # NOTE(review): every training row has mediator_factoryes == 1, so the
  # forest cannot split on it and rf_y0 will equal rf_y1 -- confirm design.
  main0 <- main1 <- main
  main0$mediator_factoryes <- 0
  main1$mediator_factoryes <- 1

  # trainControl(index = ) expects TRAINING rows for each resample.
  myFolds <- createFolds(aux$outcome, k = 5, returnTrain = TRUE)

  ###############################
  # Propensity score models - probability p of receiving the mediator and 1-p
  ###############################
  aControl <- trainControl(
    summaryFunction = mnLogLoss,
    number = 5,
    method = "cv",
    classProbs = TRUE,
    savePredictions = TRUE,
    index = myFolds
  )

  # Random forest
  cat(" ps: rf", "\n")
  rf_ps <- train(
    m_form,
    data = aux,
    method = "ranger",
    trControl = aControl,
    tuneLength = 3
  )

  rf_ps_prob <- predict(rf_ps$finalModel, data = main, type = "response")$predictions
  main$rf_ps0 <- rf_ps_prob[, "no"]
  main$rf_ps1 <- rf_ps_prob[, "yes"]

  ###############################
  # Outcome models (mediator == 1 units)
  ###############################
  m1_folds <- createFolds(aux_m1$outcome, k = 5, returnTrain = TRUE)
  yControl <- trainControl(method = "cv",
                           index = m1_folds,
                           summaryFunction = mnLogLoss,
                           classProbs = TRUE)

  #  Random Forest
  cat(" outcome: rf", "\n")
  rf_y <- train(aux_X, aux_m1$outcome, method = "ranger",
                trControl = yControl, tuneLength = 3)

  main$rf_y0 <- predict(rf_y$finalModel, data = main0, type = "response")$predictions[, "yes"]
  main$rf_y1 <- predict(rf_y$finalModel, data = main1, type = "response")$predictions[, "yes"]

  predicted_list[[k]] <- main
}

# Held-out predictions for mediator == 1 units, with a numeric outcome.
predicted_df <- bind_rows(predicted_list) %>%
  filter(mediator == 1) %>%
  mutate(outcome_num = if_else(outcome == "yes", 1, 0))

### CONSTRUCT PSEUDO OUTCOMES
methods <- c("rf")

for (psmod in methods) {
  for (ymod in methods) {

    predicted_df <- predicted_df %>%
      mutate(!!sym(str_c(psmod, "_", ymod, "_y0")) := !!sym(str_c(ymod, "_y0")) + (1 - mediator) / trim(!!sym(str_c(psmod, "_ps0"))) * (outcome_num - !!sym(str_c(ymod, "_y0"))),
             !!sym(str_c(psmod, "_", ymod, "_y1")) := !!sym(str_c(ymod, "_y1")) + mediator / trim(!!sym(str_c(psmod, "_ps1"))) * (outcome_num - !!sym(str_c(ymod, "_y1"))),
             !!sym(str_c(psmod, "_", ymod, "_TE")) := !!sym(str_c(psmod, "_", ymod, "_y1")) - !!sym(str_c(psmod, "_", ymod, "_y0")))

  }
}
  


# Parametric logit of the outcome on covariates and democracy scores.
# NOTE(review): `outcome` has levels c("yes", "no"); a binomial glm models
# the probability of the second level, i.e. P(outcome == "no"). This model
# is not used further below.
outcome_formula <- as.formula(expr(outcome ~ !!x_rhs + !!dem_variables_rhs))
outcome_model <- glm(outcome_formula, data = predicted_df, family = binomial())

# Alternative specifications explored earlier (each was immediately
# overwritten, so they are kept here for reference only):
# glm(outcome_num ~ contig_ks + major_a + alliance_gg + cinc_a + cinc_b + democ_b + ns(democ_a, df = 3),
#     data = predicted_df, family = binomial())
# glm(outcome_num ~ contig_ks + major_a + alliance_gg + cinc_a + cinc_b + democ_b + democ_a,
#     data = predicted_df, family = binomial())

# Natural spline in side A's democracy score. The model must be a logit
# (binomial family): its linear predictor and standard errors are mapped to
# probabilities with plogis() below, which is meaningless for the default
# gaussian glm.
mod <- glm(outcome_num ~ ns(democ_a, df = 2), data = predicted_df, family = binomial())

# Fitted values and standard errors on the link (logit) scale.
fitted_link <- predict(mod, se.fit = TRUE)

imputation_plot <- predicted_df %>%
  mutate(estimate = plogis(fitted_link$fit),
         ci.min = plogis(fitted_link$fit - qnorm(.975) * fitted_link$se.fit),
         ci.max = plogis(fitted_link$fit + qnorm(.975) * fitted_link$se.fit)) %>%
  ggplot(aes(x = democ_a, y = estimate)) +
  geom_line() +
  geom_ribbon(aes(
    ymin = ci.min, ymax = ci.max),
    color = NA,
    alpha = .2) +
  theme_minimal(base_size = 16) +
  scale_x_continuous(name = "Polity Score (Democracy = 10)") +
  scale_y_continuous(name = "Probability of Conflict Success")
ggsave(path = "Figures", filename = "Par_imputation.jpg", plot = imputation_plot, width = 8, height = 6)
